# -*- coding: utf-8 -*-
"""
Created on Wed Apr 12 14:22:48 2023

@author: lv
"""
import os
import torch
from models.TransformerQA import TransformerQA
from utils.JiebaTokenizer import JiebaTokenizer

tokenizer = JiebaTokenizer()

# Model hyperparameters (must match the values used during training)
vocab_size = tokenizer.vocab_size
embedding_dim = 96
hidden_dim = 1024
num_layers = 3
num_heads = 12
max_seq_len = 512
model_save_path = f'./run/{embedding_dim}_{hidden_dim}_{num_layers}_{num_heads}'

# Run inference on CPU; map_location below remaps CUDA-trained weights accordingly.
# Switch to "cuda" (when available) for faster generation.
device = torch.device("cpu")

# Build the model and restore the best checkpoint
model = TransformerQA(vocab_size, embedding_dim, hidden_dim, num_layers, num_heads, max_seq_len)
checkpoint = torch.load(os.path.join(model_save_path, 'best_model.pth'), map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)
model.eval()
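
# Optional sanity check (a minimal sketch; 'epoch' is an assumed checkpoint key
# that this repo may not store, so .get() is used defensively):
# print(f"Loaded checkpoint (epoch {checkpoint.get('epoch', '?')}) from {model_save_path}")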

def encode_input(input_text, history_text):
    # The newest question goes at the end of the context
    querystr = history_text + input_text

    # If the context is too long, truncate from the front so the most recent
    # text is kept. Note this truncates by character count, not token count,
    # so the encoded sequence may still differ in length from max_seq_len.
    if len(querystr) > max_seq_len:
        querystr = querystr[-max_seq_len:]
    # Convert the text into the id sequence the model expects
    query = tokenizer.encode(querystr)
    return torch.tensor(query), querystr
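
# Example usage (illustrative only; the exact ids depend on the JiebaTokenizer vocab):
#   ids, ctx = encode_input("今天天气怎么样", "")
#   # ids: 1-D LongTensor of token ids; ctx: the (possibly truncated) context string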

history_text = ""

# Set to True to carry the conversation history across turns
bool_history = False

# Interactive generation loop
while True:
    input_text = input("You: ")
    # Encode the current question together with any carried-over history
    encoded_input, history_text = encode_input(input_text, history_text)
    encoded_input = encoded_input.unsqueeze(0).to(device)
    # Generate an answer; gradients are not needed at inference time
    with torch.no_grad():
        output_seq = model.generate_output_sequence(encoded_input, None)
    # Greedy decoding: take the highest-scoring token at each position
    output_seq = output_seq.argmax(dim=2).squeeze(1)
    # Convert the output id sequence back into text
    output_text = tokenizer.decode(output_seq.view(-1).tolist())
    output_text = "".join(output_text)
    # Append the answer to the history, then clear it if history is disabled
    history_text = history_text + output_text
    if not bool_history:
        history_text = ""

    print("AI: " + output_text)

